import os
import numpy as np
from numpy.core.fromnumeric import size
import matplotlib.pyplot as plt
from mindspore import nn
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore import context
from cells import SigmoidCrossEntropyWithLogits


def create_dataset(data_path,
                   flatten_size,
                   batch_size,
                   repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline for CGAN training.

    Labels are one-hot encoded (10 classes) and cast to float32; images are
    rescaled from [0, 255] to [-1, 1] (matching the generator's tanh output
    range) and flattened to `flatten_size` values.

    Args:
        data_path: Directory containing the MNIST data files.
        flatten_size: Length of the flattened image vector (e.g. 28 * 28).
        batch_size: Samples per batch; incomplete final batches are dropped.
        repeat_size: Number of times the dataset is repeated.
        num_parallel_workers: Worker count for each map operation.

    Returns:
        The configured `mindspore.dataset` pipeline.
    """
    def normalize(img):
        # Rescale pixel values from [0, 255] to [-1, 1].
        return ((img - 127.5) / 127.5).astype("float32")

    def flatten(img):
        return img.reshape((flatten_size, ))

    dataset = ds.MnistDataset(data_path)
    dataset = dataset.map(input_columns="label",
                          operations=CT.OneHot(num_classes=10),
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(input_columns="label",
                          operations=CT.TypeCast(mstype.float32),
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(input_columns="image",
                          operations=normalize,
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(input_columns="image",
                          operations=flatten,
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)


def one_hot(num_classes=10, arr=None):
    """Return one-hot row vectors for the given class indices.

    Args:
        num_classes: Width of each one-hot vector.
        arr: Sequence of class indices to encode. Defaults to the ten
            digit classes [0, 1, ..., 9] (same behavior as before).

    Returns:
        np.ndarray of shape (len(arr), num_classes) with 1.0 at each index.
    """
    # Use a None sentinel instead of a mutable default list argument.
    if arr is None:
        arr = list(range(10))
    return np.eye(num_classes)[arr]


class Discriminator(nn.Cell):
    """Conditional discriminator: a 3-layer MLP that scores (image, label) pairs.

    The one-hot label (10 values) is concatenated onto the input of every
    Dense layer; the final layer emits a single unbounded logit.
    """

    def __init__(self, input_dims, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        # Each layer consumes the previous activation plus the 10-dim label.
        self.fc1 = nn.Dense(input_dims + 10, 256)
        self.fc2 = nn.Dense(256 + 10, 128)
        self.fc3 = nn.Dense(128 + 10, 1)
        self.lrelu = nn.LeakyReLU()
        self.concat = P.Concat(1)

    def construct(self, x, label):
        """Return the real/fake logit for `x` conditioned on `label`."""
        hidden = self.lrelu(self.fc1(self.concat((x, label))))
        hidden = self.lrelu(self.fc2(self.concat((hidden, label))))
        return self.fc3(self.concat((hidden, label)))


class Generator(nn.Cell):
    """Conditional generator: maps (noise, label) to a flattened image.

    The one-hot label (10 values) is concatenated onto the input of every
    Dense layer; tanh keeps the output in [-1, 1], matching the pipeline's
    image normalization.
    """

    def __init__(self, input_dims, output_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        # Each layer consumes the previous activation plus the 10-dim label.
        self.fc1 = nn.Dense(input_dims + 10, 128)
        self.fc2 = nn.Dense(128 + 10, 256)
        self.fc3 = nn.Dense(256 + 10, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.concat = P.Concat(1)

    def construct(self, x, label):
        """Generate one flattened image per latent vector in `x`."""
        hidden = self.relu(self.fc1(self.concat((x, label))))
        hidden = self.relu(self.fc2(self.concat((hidden, label))))
        return self.tanh(self.fc3(self.concat((hidden, label))))


class DisWithLossCell(nn.Cell):
    """Wraps generator + discriminator to produce the discriminator loss.

    Standard GAN objective for D: push real samples toward label 1 and
    generated samples toward label 0.
    """

    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, real_data, latent_code, label):
        """Return loss_D = loss(D(real), 1) + loss(D(G(z)), 0)."""
        real_out = self.netD(real_data, label)
        fake_out = self.netD(self.netG(latent_code, label), label)
        return (self.loss_fn(real_out, F.ones_like(real_out)) +
                self.loss_fn(fake_out, F.zeros_like(fake_out)))


class GenWithLossCell(nn.Cell):
    """Wraps generator + discriminator to produce the generator loss.

    Non-saturating GAN objective for G: push discriminator outputs on
    generated samples toward label 1.
    """

    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, latent_code, label):
        """Return loss_G = loss(D(G(z)), 1)."""
        fake_out = self.netD(self.netG(latent_code, label), label)
        return self.loss_fn(fake_out, F.ones_like(fake_out))


class TrainOneStepCell(nn.Cell):
    """Runs one training step for both networks: a D update then a G update.

    `netG` / `netD` here are the *loss* cells (GenWithLossCell /
    DisWithLossCell), not the raw networks. Gradients are taken with a
    sensitivity (initial output gradient) of `self.sens`, and are reduced
    across devices when running in data/hybrid parallel mode.
    """

    def __init__(self,
                 netG,
                 netD,
                 optimizerG: nn.Optimizer,
                 optimizerD: nn.Optimizer,
                 sens=1.0,
                 auto_prefix=True):

        super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netG.set_grad()
        self.netG.add_flags(defer_inline=True)

        self.netD = netD
        self.netD.set_grad()
        self.netD.add_flags(defer_inline=True)

        # Each optimizer owns its own parameter list (G params vs D params),
        # so the two updates cannot touch each other's weights.
        self.weights_G = optimizerG.parameters
        self.optimizerG = optimizerG
        self.weights_D = optimizerD.parameters
        self.optimizerD = optimizerD

        # sens_param=True: the loss-scaling sensitivity is passed explicitly
        # as the last argument of the gradient call in trainD/trainG.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

        self.sens = sens
        self.reducer_flag = False
        # Identity reducers by default (single-device training).
        self.grad_reducer_G = F.identity
        self.grad_reducer_D = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL,
                                  ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            # All-reduce gradients across devices in parallel modes.
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer_G = DistributedGradReducer(
                self.weights_G, mean, degree)
            self.grad_reducer_D = DistributedGradReducer(
                self.weights_D, mean, degree)

    def trainD(self, real_data, latent_code, label, loss, loss_net, grad,
               optimizer, weights, grad_reducer):
        """Apply one optimizer step to the discriminator weights.

        Returns `loss` with a control dependency on the optimizer update so
        the update is not pruned from the graph.
        """
        # Build the initial output gradient (same dtype/shape as the loss).
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(real_data, latent_code, label, sens)
        grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def trainG(self, latent_code, label, loss, loss_net, grad, optimizer,
               weights, grad_reducer):
        """Apply one optimizer step to the generator weights.

        Same mechanics as trainD, but the generator loss net takes only
        (latent_code, label).
        """
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(latent_code, label, sens)
        grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def construct(self, real_data, latent_code, label):
        """Return (d_loss, g_loss) after updating D then G once each."""
        # Forward passes first: both losses are evaluated before either
        # network's weights are updated.
        loss_D = self.netD(real_data, latent_code, label)
        loss_G = self.netG(latent_code, label)
        d_out = self.trainD(real_data, latent_code, label, loss_D, self.netD,
                            self.grad, self.optimizerD, self.weights_D,
                            self.grad_reducer_D)
        g_out = self.trainG(latent_code, label, loss_G, self.netG, self.grad,
                            self.optimizerG, self.weights_G,
                            self.grad_reducer_G)

        return d_out, g_out


# ---------------------------------------------------------------------------
# Training entry: conditional GAN on MNIST with dataset sink mode.
# ---------------------------------------------------------------------------
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
batch_size = 128
input_dim = 100  # dimensionality of the latent noise vector
epochs = 100
lr = 0.001

# Renamed from `ds` to avoid shadowing the `mindspore.dataset as ds` module
# alias: create_dataset() reads the global `ds` at call time, so rebinding it
# here would break any subsequent call to create_dataset().
train_ds = create_dataset(os.path.join("./data/MNIST_Data", "train"),
                          flatten_size=28 * 28,
                          batch_size=batch_size,
                          num_parallel_workers=2)

netG = Generator(input_dim, 28 * 28)
netD = Discriminator(28 * 28)
loss = SigmoidCrossEntropyWithLogits()
netG_with_loss = GenWithLossCell(netG, netD, loss)
netD_with_loss = DisWithLossCell(netG, netD, loss)
optimizerG = nn.Adam(netG.trainable_params(), lr)
optimizerD = nn.Adam(netD.trainable_params(), lr)

net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG,
                             optimizerD)

dataset_helper = DatasetHelper(train_ds,
                               epoch_num=epochs,
                               dataset_sink_mode=True)
net_train = connect_network_with_dataset(net_train, dataset_helper)

# Make sure the sample-image directory exists before the first savefig call.
os.makedirs("./images", exist_ok=True)

netG.set_train()
netD.set_train()
for epoch in range(epochs):
    step = 1
    for data in dataset_helper:
        imgs = data[0]
        label = data[1]
        latent_code = Tensor(np.random.normal(size=(batch_size, input_dim)),
                             dtype=mstype.float32)
        dout, gout = net_train(imgs, latent_code, label)
        if step % 100 == 0:
            print(
                "epoch {} step {}, d_loss is {:.4f}, g_loss is {:.4f}".format(
                    epoch, step, dout.asnumpy(), gout.asnumpy()))
        step += 1

    # Render a 10x4 grid: four samples per digit class, saved per epoch.
    for digit in range(10):
        for i in range(4):
            latent_code = Tensor(np.random.normal(size=(1, input_dim)),
                                 dtype=mstype.float32)
            label = Tensor(one_hot(arr=[digit]), dtype=mstype.float32)
            gen_imgs = netG(latent_code, label).asnumpy()
            gen_imgs = gen_imgs.reshape((28, 28))
            plt.subplot(10, 4, digit * 4 + i + 1)
            # Undo the [-1, 1] normalization back to [0, 255] pixel values.
            plt.imshow(gen_imgs * 127.5 + 127.5, cmap="gray")
            plt.axis("off")
    plt.savefig("./images/{}.jpg".format(epoch))